package com.cicadasmall.mybatis.interceptor;

import cn.hutool.core.util.ClassUtil;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
import com.alibaba.druid.sql.ast.statement.SQLSelect;
import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.parser.SQLExprParser;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.util.JdbcUtils;
import com.cicadasmall.common.base.LoginUser;
import com.cicadasmall.common.utils.SecurityUtils;
import com.cicadasmall.mybatis.annotation.DataScope;
import com.cicadasmall.common.func.Fn;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

import java.util.ArrayList;
import java.util.List;
import java.util.Properties;


/**
 * DataScoreInterceptor 未实际测试
 *
 * @author Jin
 */
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
})
public class DataScoreInterceptor implements Interceptor {

    final static int MAPPED_STATEMENT_INDEX = 0;
    final static int PARAMETER_INDEX = 1;

    @Override
    public Object intercept(Invocation invocation) throws InvocationTargetException, IllegalAccessException {
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[MAPPED_STATEMENT_INDEX];
        if (Fn.notEqual(SqlCommandType.SELECT, mappedStatement.getSqlCommandType())) {
            return invocation.proceed();
        }

        DataScope dataScope = getAnnotation(mappedStatement);
        if (Fn.isNull(dataScope)) {
            return invocation.proceed();
        }

        final Object parameter = args[PARAMETER_INDEX];
        final BoundSql boundSql = mappedStatement.getBoundSql(parameter);

        String newSql = appendDataScopeCondition(boundSql.getSql(), dataScope);

        if (Fn.isEmpty(newSql)) {
            return invocation.proceed();
        }

        BoundSql newBoundSql = new BoundSql(mappedStatement.getConfiguration(), newSql, boundSql.getParameterMappings(), boundSql.getParameterObject());

        MappedStatement newMappedStatement = copyFromMappedStatement(mappedStatement, new BoundSqlSqlSource(newBoundSql));
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        args[MAPPED_STATEMENT_INDEX] = newMappedStatement;

        return invocation.proceed();
    }

    private String appendDataScopeCondition(String sql, DataScope annotation) {
        SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, JdbcUtils.MYSQL);
        List<SQLStatement> stmtList = parser.parseStatementList();
        SQLStatement stmt = stmtList.get(0);

        if (stmt instanceof SQLSelectStatement) {

            StringBuilder stringBuilder = new StringBuilder();
            LoginUser loginUser = SecurityUtils.getCurrentLoginUser();
            if (Fn.isNull(loginUser)) {
                return null;
            }

            List<String> dataScopes = new ArrayList<>(loginUser.getDataScopes());
            if (Fn.isNotEmpty(loginUser.getDataScopes())) {
                for (int i = 0; i < dataScopes.size(); i++) {
                    String dataScope = dataScopes.get(i);
                    stringBuilder.append(" ");
                    if (Fn.isNotEmpty(annotation.alias())) {
                        stringBuilder.append(annotation.alias())
                                .append(".");
                    }
                    stringBuilder.append(Fn.isNotEmpty(annotation.fieldName()) ? annotation.fieldName() : "create_org")
                            .append("IN (")
                            .append(dataScope)
                            .append(")")
                            .append(" ");
                    if (i < dataScopes.size() - 1) {
                        stringBuilder.append(" AND ");
                    }
                }
            }

            SQLExprParser constraintsParser = SQLParserUtils.createExprParser(stringBuilder.toString(), JdbcUtils.MYSQL);
            SQLExpr constraintsExpr = constraintsParser.expr();

            SQLSelectStatement selectStmt = (SQLSelectStatement) stmt;

            SQLSelect sqlselect = selectStmt.getSelect();
            SQLSelectQueryBlock query = (SQLSelectQueryBlock) sqlselect.getQuery();
            SQLExpr whereExpr = query.getWhere();

            if (whereExpr == null) {
                query.setWhere(constraintsExpr);
            } else {
                SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr);
                query.setWhere(newWhereExpr);
            }
            sqlselect.setQuery(query);
            return sqlselect.toString();
        }
        return sql;
    }


    private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            builder.keyProperty(ms.getKeyProperties()[0]);
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    private DataScope getAnnotation(MappedStatement mappedStatement) {
        try {
            String id = mappedStatement.getId();
            String className = id.substring(0, id.lastIndexOf("."));
            String methodName = id.substring(id.lastIndexOf(".") + 1);
            final Class<?> clazz = Class.forName(className);
            final Method method = ClassUtil.getPublicMethod(clazz, methodName);
            if (Fn.isNotNull(method) && method.isAnnotationPresent(DataScope.class)) {
                return method.getAnnotation(DataScope.class);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties properties) {
        // to do nothing
    }
}
